This tutorial illustrates the core visualization utilities available in Ax.
import numpy as np
from ax.service.ax_client import AxClient
from ax.modelbridge.cross_validation import cross_validate
from ax.plot.contour import interact_contour
from ax.plot.diagnostic import interact_cross_validation
from ax.plot.scatter import (
    interact_fitted,
    plot_objective_vs_constraints,
    tile_fitted,
)
from ax.plot.slice import plot_slice
from ax.utils.measurement.synthetic_functions import hartmann6
from ax.utils.notebook.plotting import render, init_notebook_plotting
init_notebook_plotting()
The visualizations require an experiment object and a model fit on the evaluated data. The routine below is a copy of the Service API tutorial, so its explanation is omitted here. How to retrieve the experiment and model objects for each API paradigm is shown in the respective tutorials.
noise_sd = 0.1
param_names = [f"x{i+1}" for i in range(6)] # x1, x2, ..., x6
def noisy_hartmann_evaluation_function(parameterization):
    x = np.array([parameterization.get(p_name) for p_name in param_names])
    noise1, noise2 = np.random.normal(0, noise_sd, 2)
    return {
        "hartmann6": (hartmann6(x) + noise1, noise_sd),
        "l2norm": (np.sqrt((x ** 2).sum()) + noise2, noise_sd),
    }
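For reference, `hartmann6` is the standard six-dimensional Hartmann test function, whose global minimum in the unit cube is about -3.32237. The sketch below is a pure-Python version of the textbook definition; the name `hartmann6_ref` is hypothetical, and we assume (without having checked Ax's source) that Ax's `hartmann6` uses the same constants.

```python
import math

# Textbook Hartmann6 constants (assumed to match Ax's `hartmann6`).
ALPHA = [1.0, 1.2, 3.0, 3.2]
A = [
    [10.0, 3.0, 17.0, 3.5, 1.7, 8.0],
    [0.05, 10.0, 17.0, 0.1, 8.0, 14.0],
    [3.0, 3.5, 1.7, 10.0, 17.0, 8.0],
    [17.0, 8.0, 0.05, 10.0, 0.1, 14.0],
]
P = [
    [0.1312, 0.1696, 0.5569, 0.0124, 0.8283, 0.5886],
    [0.2329, 0.4135, 0.8307, 0.3736, 0.1004, 0.9991],
    [0.2348, 0.1451, 0.3522, 0.2883, 0.3047, 0.6650],
    [0.4047, 0.8828, 0.8732, 0.5743, 0.1091, 0.0381],
]


def hartmann6_ref(x):
    """Hypothetical reference implementation of Hartmann6 on [0, 1]^6."""
    if len(x) != 6:
        raise ValueError("hartmann6 expects exactly 6 inputs")
    total = 0.0
    for alpha_i, a_row, p_row in zip(ALPHA, A, P):
        # Weighted squared distance from the i-th "well" centre P[i].
        inner = sum(a * (xj - p) ** 2 for a, xj, p in zip(a_row, x, p_row))
        total += alpha_i * math.exp(-inner)
    return -total


# Published location of the global minimum (rounded to 6 digits).
x_star = [0.20169, 0.150011, 0.476874, 0.275332, 0.311652, 0.6573]
print(round(hartmann6_ref(x_star), 3))  # -3.322
```

This gives a yardstick for the trial log below: the best observed values there are still some distance from the global minimum.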
ax_client = AxClient()
ax_client.create_experiment(
    name="test_visualizations",
    parameters=[
        {
            "name": p_name,
            "type": "range",
            "bounds": [0.0, 1.0],
        }
        for p_name in param_names
    ],
    objective_name="hartmann6",
    minimize=True,
    outcome_constraints=["l2norm <= 1.25"],
)
[INFO 03-10 13:52:47] ax.service.ax_client: Starting optimization with verbose logging. To disable logging, set the `verbose_logging` argument to `False`. Note that float values in the logs are rounded to 6 decimal points.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x1. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x2. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x3. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x4. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x5. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Inferred value type of ParameterType.FLOAT for parameter x6. If that is not the expected value type, you can explicity specify 'value_type' ('int', 'float', 'bool' or 'str') in parameter dict.
[INFO 03-10 13:52:47] ax.service.utils.instantiation: Created search space: SearchSpace(parameters=[RangeParameter(name='x1', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x2', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x3', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x4', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x5', parameter_type=FLOAT, range=[0.0, 1.0]), RangeParameter(name='x6', parameter_type=FLOAT, range=[0.0, 1.0])], parameter_constraints=[]).
[INFO 03-10 13:52:47] ax.modelbridge.dispatch_utils: Using Bayesian optimization since there are more ordered parameters than there are categories for the unordered categorical parameters.
[INFO 03-10 13:52:47] ax.modelbridge.dispatch_utils: Using Bayesian Optimization generation strategy: GenerationStrategy(name='Sobol+GPEI', steps=[Sobol for 12 trials, GPEI for subsequent trials]). Iterations after 12 will take longer to generate due to model-fitting.
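The outcome constraint `l2norm <= 1.25` restricts the Euclidean norm of the parameter vector. Ax only sees noisy `l2norm` observations and models the constraint, but the noiseless sketch below (the helpers `l2norm` and `satisfies_constraint` are hypothetical, not Ax functions) shows which points are feasible. Notably, the centre of the search space has norm sqrt(1.5), about 1.2247, so it lies just inside the feasible region.

```python
import math

# Hypothetical helpers (not part of Ax) mirroring "l2norm <= 1.25"
# with the noiseless L2 norm of the parameter vector.
L2NORM_BOUND = 1.25


def l2norm(parameterization, names):
    """Noiseless Euclidean norm of the parameters listed in `names`."""
    return math.sqrt(sum(parameterization[n] ** 2 for n in names))


def satisfies_constraint(parameterization, names, bound=L2NORM_BOUND):
    """True if the noiseless l2norm does not exceed `bound`."""
    return l2norm(parameterization, names) <= bound


names = [f"x{i+1}" for i in range(6)]
center = {n: 0.5 for n in names}  # norm = sqrt(1.5), about 1.2247
corner = {n: 1.0 for n in names}  # norm = sqrt(6), about 2.449
print(satisfies_constraint(center, names))  # True
print(satisfies_constraint(corner, names))  # False
```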
for _ in range(20):
    parameters, trial_index = ax_client.get_next_trial()
    # Local evaluation here can be replaced with deployment to external system.
    ax_client.complete_trial(trial_index=trial_index, raw_data=noisy_hartmann_evaluation_function(parameters))
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 0 with parameters {'x1': 0.581043, 'x2': 0.055999, 'x3': 0.955743, 'x4': 0.617451, 'x5': 0.948157, 'x6': 0.693985}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 0 with data: {'hartmann6': (-0.133452, 0.1), 'l2norm': (1.641608, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 1 with parameters {'x1': 0.211094, 'x2': 0.741626, 'x3': 0.679923, 'x4': 0.502716, 'x5': 0.590746, 'x6': 0.119463}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 1 with data: {'hartmann6': (-1.247528, 0.1), 'l2norm': (1.233822, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 2 with parameters {'x1': 0.13418, 'x2': 0.893376, 'x3': 0.457086, 'x4': 0.533917, 'x5': 0.357445, 'x6': 0.719699}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 2 with data: {'hartmann6': (-0.420069, 0.1), 'l2norm': (1.428524, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 3 with parameters {'x1': 0.662665, 'x2': 0.572793, 'x3': 0.870751, 'x4': 0.700741, 'x5': 0.249633, 'x6': 0.153307}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 3 with data: {'hartmann6': (-0.195585, 0.1), 'l2norm': (1.205406, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 4 with parameters {'x1': 0.416389, 'x2': 0.768014, 'x3': 0.642841, 'x4': 0.140313, 'x5': 0.417642, 'x6': 0.326798}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 4 with data: {'hartmann6': (-0.460025, 0.1), 'l2norm': (1.182996, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 5 with parameters {'x1': 0.035836, 'x2': 0.690288, 'x3': 0.703162, 'x4': 0.778082, 'x5': 0.986822, 'x6': 0.940296}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 5 with data: {'hartmann6': (-0.071108, 0.1), 'l2norm': (1.944891, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 6 with parameters {'x1': 0.245147, 'x2': 0.198984, 'x3': 0.438857, 'x4': 0.291856, 'x5': 0.756427, 'x6': 0.826087}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 6 with data: {'hartmann6': (-0.352466, 0.1), 'l2norm': (1.342101, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 7 with parameters {'x1': 0.444366, 'x2': 0.975587, 'x3': 0.769879, 'x4': 0.452199, 'x5': 0.607928, 'x6': 0.109777}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 7 with data: {'hartmann6': (-2.305612, 0.1), 'l2norm': (1.487042, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 8 with parameters {'x1': 0.476339, 'x2': 0.326533, 'x3': 0.617167, 'x4': 0.316988, 'x5': 0.681075, 'x6': 0.392879}.
[INFO 03-10 13:52:47] ax.service.ax_client: Completed trial 8 with data: {'hartmann6': (-0.317739, 0.1), 'l2norm': (1.160549, 0.1)}.
[INFO 03-10 13:52:47] ax.service.ax_client: Generated new trial 9 with parameters {'x1': 0.100663, 'x2': 0.61659, 'x3': 0.742754, 'x4': 0.241038, 'x5': 0.098407, 'x6': 0.405969}.
[INFO 03-10 13:52:48] ax.service.ax_client: Completed trial 9 with data: {'hartmann6': (-0.480252, 0.1), 'l2norm': (1.039798, 0.1)}.
[INFO 03-10 13:52:48] ax.service.ax_client: Generated new trial 10 with parameters {'x1': 0.744745, 'x2': 0.089732, 'x3': 0.680274, 'x4': 0.245723, 'x5': 0.699597, 'x6': 0.184114}.
[INFO 03-10 13:52:48] ax.service.ax_client: Completed trial 10 with data: {'hartmann6': (-0.039595, 0.1), 'l2norm': (1.253578, 0.1)}.
[INFO 03-10 13:52:48] ax.service.ax_client: Generated new trial 11 with parameters {'x1': 0.857053, 'x2': 0.791758, 'x3': 0.538444, 'x4': 0.611314, 'x5': 0.254166, 'x6': 0.726559}.
[INFO 03-10 13:52:48] ax.service.ax_client: Completed trial 11 with data: {'hartmann6': (-0.158541, 0.1), 'l2norm': (1.620653, 0.1)}.
[INFO 03-10 13:53:16] ax.service.ax_client: Generated new trial 12 with parameters {'x1': 0.392713, 'x2': 0.751776, 'x3': 0.760579, 'x4': 0.415256, 'x5': 0.505964, 'x6': 0.178762}.
[INFO 03-10 13:53:16] ax.service.ax_client: Completed trial 12 with data: {'hartmann6': (-1.701544, 0.1), 'l2norm': (1.342492, 0.1)}.
[INFO 03-10 13:54:12] ax.service.ax_client: Generated new trial 13 with parameters {'x1': 0.330748, 'x2': 0.727469, 'x3': 0.70326, 'x4': 0.403077, 'x5': 0.396572, 'x6': 0.24499}.
[INFO 03-10 13:54:12] ax.service.ax_client: Completed trial 13 with data: {'hartmann6': (-0.977819, 0.1), 'l2norm': (1.216724, 0.1)}.
[INFO 03-10 13:55:12] ax.service.ax_client: Generated new trial 14 with parameters {'x1': 0.408331, 'x2': 0.751218, 'x3': 0.621685, 'x4': 0.464478, 'x5': 0.565027, 'x6': 0.071086}.
[INFO 03-10 13:55:12] ax.service.ax_client: Completed trial 14 with data: {'hartmann6': (-2.287696, 0.1), 'l2norm': (1.336879, 0.1)}.
[INFO 03-10 13:56:04] ax.service.ax_client: Generated new trial 15 with parameters {'x1': 0.39935, 'x2': 0.613381, 'x3': 0.591231, 'x4': 0.416722, 'x5': 0.556472, 'x6': 0.097024}.
[INFO 03-10 13:56:04] ax.service.ax_client: Completed trial 15 with data: {'hartmann6': (-1.441641, 0.1), 'l2norm': (1.093187, 0.1)}.
[INFO 03-10 13:56:48] ax.service.ax_client: Generated new trial 16 with parameters {'x1': 0.414644, 'x2': 0.763956, 'x3': 0.472371, 'x4': 0.454108, 'x5': 0.607982, 'x6': 0.073006}.
[INFO 03-10 13:56:48] ax.service.ax_client: Completed trial 16 with data: {'hartmann6': (-2.538693, 0.1), 'l2norm': (1.22708, 0.1)}.
[INFO 03-10 13:58:02] ax.service.ax_client: Generated new trial 17 with parameters {'x1': 0.424589, 'x2': 0.711907, 'x3': 0.475457, 'x4': 0.437215, 'x5': 0.610885, 'x6': 0.129027}.
[INFO 03-10 13:58:02] ax.service.ax_client: Completed trial 17 with data: {'hartmann6': (-1.847839, 0.1), 'l2norm': (1.158552, 0.1)}.
[INFO 03-10 13:58:52] ax.service.ax_client: Generated new trial 18 with parameters {'x1': 0.410265, 'x2': 0.813162, 'x3': 0.378261, 'x4': 0.416485, 'x5': 0.599477, 'x6': 0.019999}.
[INFO 03-10 13:58:52] ax.service.ax_client: Completed trial 18 with data: {'hartmann6': (-2.482171, 0.1), 'l2norm': (1.242579, 0.1)}.
[INFO 03-10 14:00:41] ax.service.ax_client: Generated new trial 19 with parameters {'x1': 0.42755, 'x2': 0.801235, 'x3': 0.430219, 'x4': 0.51233, 'x5': 0.617904, 'x6': 0.047163}.
[INFO 03-10 14:00:41] ax.service.ax_client: Completed trial 19 with data: {'hartmann6': (-2.917151, 0.1), 'l2norm': (1.274324, 0.1)}.
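The log above can be read directly to find the best observed trial that satisfies the constraint. The sketch below copies the observed means from the log into a dictionary and keeps only trials with `l2norm <= 1.25`; `best_feasible_trial` is a hypothetical helper. Because the observations are noisy (standard deviation 0.1), this raw-data ranking is only indicative, and a model-based estimate may order the top trials differently.

```python
# Observed means copied from the trial log: trial -> (hartmann6, l2norm).
observed = {
    0: (-0.133452, 1.641608), 1: (-1.247528, 1.233822),
    2: (-0.420069, 1.428524), 3: (-0.195585, 1.205406),
    4: (-0.460025, 1.182996), 5: (-0.071108, 1.944891),
    6: (-0.352466, 1.342101), 7: (-2.305612, 1.487042),
    8: (-0.317739, 1.160549), 9: (-0.480252, 1.039798),
    10: (-0.039595, 1.253578), 11: (-0.158541, 1.620653),
    12: (-1.701544, 1.342492), 13: (-0.977819, 1.216724),
    14: (-2.287696, 1.336879), 15: (-1.441641, 1.093187),
    16: (-2.538693, 1.22708), 17: (-1.847839, 1.158552),
    18: (-2.482171, 1.242579), 19: (-2.917151, 1.274324),
}


def best_feasible_trial(observations, bound=1.25):
    """Trial with the lowest observed hartmann6 among those with l2norm <= bound."""
    feasible = {t: h for t, (h, l2) in observations.items() if l2 <= bound}
    if not feasible:
        return None
    return min(feasible, key=feasible.get)


print(best_feasible_trial(observed))  # 16
```

Trial 19 has the lowest observed hartmann6 value, but its observed l2norm of 1.274324 exceeds the bound, so the sketch picks trial 16 instead.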
The plot below shows the response surface for the hartmann6 metric as a function of the x1 and x2 parameters.
The other parameters are fixed at the middle of their respective ranges, which in this example is 0.5 for all of them.
# this could alternatively be done with `ax.plot.contour.plot_contour`
render(ax_client.get_contour_plot(param_x="x1", param_y="x2", metric_name='hartmann6'))
[INFO 03-10 14:00:42] ax.service.ax_client: Retrieving contour plot with parameter 'x1' on X-axis and 'x2' on Y-axis, for metric 'hartmann6'. Remaining parameters are affixed to the middle of their range.
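Conceptually, the contour plot evaluates a function on a grid over x1 and x2 while pinning every other parameter at its midpoint. In Ax the function being evaluated is the fitted model's prediction, not the true hartmann6. The sketch below only illustrates the grid construction; `contour_grid`, the stand-in `sum_of_squares` function, and the coarse 5-by-5 resolution are all assumptions chosen for illustration.

```python
def contour_grid(f, names, x_name, y_name, n=5, lower=0.0, upper=1.0):
    """Evaluate f on an n-by-n grid over (x_name, y_name).

    All other parameters are pinned at the midpoint of [lower, upper].
    grid[j][i] holds the value at x_name = ticks[i], y_name = ticks[j].
    """
    mid = (lower + upper) / 2
    step = (upper - lower) / (n - 1)
    ticks = [lower + k * step for k in range(n)]
    grid = []
    for y in ticks:
        row = []
        for x in ticks:
            point = {name: mid for name in names}  # unvaried params at midpoint
            point[x_name] = x
            point[y_name] = y
            row.append(f(point))
        grid.append(row)
    return ticks, grid


def sum_of_squares(point):
    """Stand-in objective; a real contour plot would use model predictions."""
    return sum(v * v for v in point.values())


names = [f"x{i+1}" for i in range(6)]
ticks, grid = contour_grid(sum_of_squares, names, "x1", "x2")
print(ticks)       # [0.0, 0.25, 0.5, 0.75, 1.0]
print(grid[0][0])  # 1.0 (x1 = x2 = 0, four parameters at 0.5)
```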
The plot below allows toggling between different pairs of parameters to view the contours.
model = ax_client.generation_strategy.model
render(interact_contour(model=model, metric_name='hartmann6'))
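With six parameters there are 15 unordered pairs one might inspect this way. A quick way to list them, as a standard-library sketch:

```python
from itertools import combinations

param_names = [f"x{i+1}" for i in range(6)]
pairs = list(combinations(param_names, 2))  # unordered pairs, no repeats
print(len(pairs))           # 15
print(pairs[0], pairs[-1])  # ('x1', 'x2') ('x5', 'x6')
```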
This plot illustrates the tradeoffs achievable between two different metrics. The plot takes the x-axis metric as input (usually the objective) and allows toggling among all other metrics for the y-axis.
This is useful for getting a sense of the Pareto frontier (i.e., the best objective value achievable under different bounds on the constraint).
render(plot_objective_vs_constraints(model, 'hartmann6', rel=False))
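The Pareto frontier can also be approximated from raw observations: a point lies on the frontier if no other point is at least as good on both metrics and strictly better on one. The sketch below treats both hartmann6 and l2norm as lower-is-better and uses a subset of the observed means from the trial log. `pareto_front` is a hypothetical helper, and unlike the plot above it ignores noise and model predictions.

```python
def pareto_front(points):
    """Indices of points not dominated when minimizing both coordinates."""
    front = []
    for i, (a1, b1) in points.items():
        dominated = any(
            a2 <= a1 and b2 <= b1 and (a2 < a1 or b2 < b1)
            for j, (a2, b2) in points.items()
            if j != i
        )
        if not dominated:
            front.append(i)
    return sorted(front)


# Observed means from the trial log: trial -> (hartmann6, l2norm).
observed = {
    7: (-2.305612, 1.487042),
    9: (-0.480252, 1.039798),
    14: (-2.287696, 1.336879),
    15: (-1.441641, 1.093187),
    16: (-2.538693, 1.22708),
    17: (-1.847839, 1.158552),
    18: (-2.482171, 1.242579),
    19: (-2.917151, 1.274324),
}
print(pareto_front(observed))  # [9, 15, 16, 17, 19]
```

Reading the frontier left to right shows the tradeoff: each step to a lower hartmann6 value costs a larger l2norm, and only trial 19 crosses the 1.25 bound.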
Cross-validation (CV) plots are useful for checking how well the model's predictions agree with the actual measurements. If all points lie close to the dashed line, the model is a good predictor of the real data.
cv_results = cross_validate(model)
render(interact_cross_validation(cv_results))
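Assuming `cross_validate` is used with its default leave-one-out setting, each observation is predicted by a model refit without that observation, and the prediction is compared with the measured value. The sketch below shows the same procedure with a one-dimensional least-squares line swapped in for Ax's Gaussian process model; `fit_line` and `loo_predictions` are hypothetical helpers, and the data are synthetic.

```python
def fit_line(xs, ys):
    """Ordinary least-squares fit y = slope * x + intercept."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


def loo_predictions(xs, ys):
    """Predict each point from a line fit to all the other points."""
    preds = []
    for i in range(len(xs)):
        slope, intercept = fit_line(xs[:i] + xs[i + 1:], ys[:i] + ys[i + 1:])
        preds.append(slope * xs[i] + intercept)
    return preds


# Toy data: y = 2x + 1 except for an outlier at the last point.
xs = [0.0, 1.0, 2.0, 3.0, 4.0]
ys = [1.0, 3.0, 5.0, 7.0, 20.0]
preds = loo_predictions(xs, ys)
residuals = [abs(p - y) for p, y in zip(preds, ys)]
worst = max(range(len(residuals)), key=residuals.__getitem__)
print(worst)  # 4
```

A point far from the dashed line in the CV plot corresponds to a large residual here; in this toy data the outlier at index 4 is predicted as 9.0 but measured as 20.0.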
Slice plots show the metric outcome as a function of one parameter while the others are held fixed. They serve a similar purpose to contour plots.
render(plot_slice(model, "x2", "hartmann6"))
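A slice is the one-dimensional analogue of a contour grid: one parameter sweeps its range while the rest stay at their midpoints. A minimal sketch, with a hypothetical `slice_values` helper and a stand-in function in place of the model's prediction (Ax's slice plot also draws a confidence band, which this omits):

```python
def slice_values(f, names, vary, n=5, lower=0.0, upper=1.0):
    """Evaluate f along `vary` with all other parameters at their midpoint."""
    mid = (lower + upper) / 2
    step = (upper - lower) / (n - 1)
    xs = [lower + k * step for k in range(n)]
    values = []
    for v in xs:
        point = {name: mid for name in names}  # fixed params at midpoint
        point[vary] = v
        values.append(f(point))
    return xs, values


names = [f"x{i+1}" for i in range(6)]
xs, ys = slice_values(lambda p: p["x2"] ** 2, names, "x2")
print(ys)  # [0.0, 0.0625, 0.25, 0.5625, 1.0]
```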
Tile plots show the model's fitted value of each metric for every arm, which makes them useful for viewing the effect of each arm.
render(interact_fitted(model, rel=False))
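The tile plots display fitted values with confidence intervals. A rough raw-data analogue is a normal-approximation 95% interval (mean plus or minus 1.96 times the SEM) around each arm's observations. The sketch below applies this to three trials from the log, using the SEM of 0.1 that was passed in as `noise_sd`; `interval` and `overlaps` are hypothetical helpers.

```python
Z_95 = 1.96  # two-sided 95% normal quantile (rounded)


def interval(mean, sem, z=Z_95):
    """Normal-approximation confidence interval for an observed mean."""
    return (mean - z * sem, mean + z * sem)


def overlaps(a, b):
    """True if closed intervals a and b share at least one point."""
    return a[0] <= b[1] and b[0] <= a[1]


# (mean, SEM) pairs copied from the trial log above.
observed = {
    16: {"hartmann6": (-2.538693, 0.1), "l2norm": (1.22708, 0.1)},
    18: {"hartmann6": (-2.482171, 0.1), "l2norm": (1.242579, 0.1)},
    19: {"hartmann6": (-2.917151, 0.1), "l2norm": (1.274324, 0.1)},
}
cis = {
    trial: {name: interval(*obs) for name, obs in metrics.items()}
    for trial, metrics in observed.items()
}
for trial, metrics in cis.items():
    for name, (lo, hi) in metrics.items():
        print(f"trial {trial} {name}: [{lo:.3f}, {hi:.3f}]")
print(overlaps(cis[16]["hartmann6"], cis[18]["hartmann6"]))  # True
```

The hartmann6 intervals of trials 16 and 18 overlap, so the observations alone cannot separate those two arms, and trial 19's l2norm interval straddles the 1.25 bound.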